
import org.apache.commons.lang3.ArrayUtils;
import org.apache.curator.framework.CuratorFramework;
import org.apache.curator.framework.imps.CuratorFrameworkState;
import org.apache.curator.framework.recipes.cache.PathChildrenCache;
import org.apache.zookeeper.CreateMode;
import org.apache.zookeeper.KeeperException;
import org.apache.zookeeper.data.Stat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * @Author lnk
 * @Date 2018/10/18
 */
public class ZookeeperSchedulerImpl extends AbstractScheduler {
    private static final Logger logger = LoggerFactory.getLogger(ZookeeperSchedulerImpl.class);

    // 存放所有任务的根节点
    private static final String ROOT_PATH = "/_scheduler";
    // 所有服务器的等级节点，在parentPath节点下
    private static final String LEVEL_PATH = ROOT_PATH + "/_level";
    // 当前集群中服务器的最高级别
    private int maxAliveLevel = Integer.MAX_VALUE;
    private CuratorFramework zkClient;
    private PathChildrenCache levelListenable;

    public ZookeeperSchedulerImpl(CuratorFramework zkClient) throws Exception {
        this.zkClient = Objects.requireNonNull(zkClient, "zkClient 不能为空");
        init();
    }

    private void init() throws Exception {
        if (zkClient.getState() == CuratorFrameworkState.LATENT) {
            zkClient.start();
        }

        Stat stat = zkClient.checkExists().forPath(ROOT_PATH);
        if (stat == null) {
            zkClient.create().creatingParentsIfNeeded().forPath(ROOT_PATH);
        }
    }

    @Override
    public boolean check(String id) {
        Long time = getLong(getNodePathById(id));
        return time == null || currentTimeMillis() > time;
    }

    @Override
    public boolean lock(String id, long timeout) {
        String nodePath = getNodePathById(id);
        try {
            ByteBuffer buffer = ByteBuffer.allocate(8);
            byte[] array = buffer.putLong(currentTimeMillis() + timeout).array();
            Stat stat = zkClient.checkExists().forPath(nodePath);
            if (stat != null) {
                if (!check(id)) {
                    return false;
                }

                zkClient.setData().withVersion(stat.getVersion()).forPath(nodePath, array);
            } else {
//                zkClient.create().withTtl(timeout).withMode(CreateMode.PERSISTENT_WITH_TTL).forPath(nodePath, array);
                zkClient.create().withMode(CreateMode.PERSISTENT).forPath(nodePath, array);
            }
            return true;
        } catch (Exception e) {
            if (!(e instanceof KeeperException.BadVersionException)) {
                logger.error("lock error", e);
            }
            return false;
        }
    }

    @Override
    public void relock(String id, long timeout) {
        String nodePath = getNodePathById(id);
        ByteBuffer buffer = ByteBuffer.allocate(8);
        byte[] array = buffer.putLong(currentTimeMillis() + timeout).array();
        try {
            zkClient.setData().forPath(nodePath, array);
        } catch (Exception e) {
            logger.error("relock error", e);
        }
    }

    /**
     * 此处建议修改为从指定的服务器获取时间，避免不同服务器时间差距太大，导致任务重复执行
     * 又或者服务器之间做时间同步，保证集群服务的系统时间一致
     *
     * @return
     */
    @Override
    public long currentTimeMillis() {
        return System.currentTimeMillis();
    }

    /**
     * 这里使用的zookeeper的监听功能，spring.scheduling.cluster.heartTime的配置可以长一点时间，如1天，但不能不配置
     */
    @Override
    public void keepAlive() {
        try {
            String levelNode = LEVEL_PATH.concat("/").concat(String.valueOf(getLevel()));
            Stat stat = zkClient.checkExists().forPath(levelNode);
            // 如果没有相同等级的服务器注册，就将本机服务器等级注册到zookeeper
            if (stat == null) {
                zkClient.create().creatingParentsIfNeeded().withMode(CreateMode.EPHEMERAL).forPath(levelNode);
            }
        } catch (Exception e) {
            logger.error("keepAlive error", e);
        } finally {
            if (levelListenable != null) {
                return;
            }
            PathChildrenCache pathChildrenCache = new PathChildrenCache(zkClient, LEVEL_PATH, false);
            pathChildrenCache.getListenable().addListener((client, event) -> {
                logger.info("change level {}", getLevel());
                maxAliveLevel = Integer.MAX_VALUE;
                getMaxAliveLevel();
            });
            try {
                pathChildrenCache.start();
                levelListenable = pathChildrenCache;
            } catch (Exception e) {
                logger.error("Listenable error", e);
            }
        }
    }

    @Override
    public int getMaxAliveLevel() {
        if (maxAliveLevel != Integer.MAX_VALUE) {
            return maxAliveLevel;
        }
        synchronized (this) {
            if (maxAliveLevel != Integer.MAX_VALUE) {
                return maxAliveLevel;
            }

            try {
                // 获取全部的服务器级别，找出最高的级别
                List<String> nodes = zkClient.getChildren().forPath(LEVEL_PATH);
                int[] levels = nodes.stream().mapToInt(s -> Integer.parseInt(s)).toArray();
                if (!ArrayUtils.contains(levels, getLevel())) {
                    // 如果列表中已经不存在本机的级别，就将级别注册到zookeeper中
                    keepAlive();
                    logger.debug("Registered zookeeper level:{}", getLevel());
                }
                Arrays.sort(levels);
                maxAliveLevel = levels[0];
                logger.info("update max alive level to {}", maxAliveLevel);
            } catch (Exception e) {
                logger.error("getMaxAliveLevel error", e);
            }
        }
        return super.getMaxAliveLevel();
    }

    private String getNodePathById(String id) {
        return ROOT_PATH.concat("/").concat(id);
    }

    private Long getLong(String nodePath) {
        try {
            Stat stat = zkClient.checkExists().forPath(nodePath);
            if (stat != null) {
                byte[] bytes = zkClient.getData().forPath(nodePath);
                return getByteBufferAndFlip(bytes).getLong();
            }
        } catch (Exception e) {
            logger.error("getLong error", e);
        }
        return null;
    }

    private ByteBuffer getByteBufferAndFlip(byte[] bytes) {
        ByteBuffer buffer = ByteBuffer.allocate(bytes.length).put(bytes);
        buffer.flip();
        return buffer;
    }
}
